{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 2.  1.  0. ...  7.  6.  5.]\n",
      " [-6. -8. -9. ... 14. 14. 15.]\n",
      " [ 5.  3.  3. ...  5.  6.  8.]\n",
      " ...\n",
      " [-2. -4. -3. ...  3.  2.  2.]\n",
      " [ 0. -2. -3. ...  8.  8.  8.]\n",
      " [10.  9.  9. ...  9.  6.  3.]]\n"
     ]
    }
   ],
   "source": [
    "## 读取睡眠标签及观察值序列文件hmm_obsv.txt \n",
    "cwd = os.getcwd()\n",
    "HMM_data_path = os.path.join(cwd, 'HMM_data')\n",
    "\n",
    "file =  os.path.join(HMM_data_path, 'hmm_obsv.txt') \n",
    "fid = open(file,\"r\")\n",
    "note = fid.readlines()\n",
    "fid.close()\n",
    "\n",
    "states = [] # 存储hmm_obsv.txt的睡眠标签\n",
    "\n",
    "T = 960 #设置观察序列长度T=2*60*8=960\n",
    "observations_data = np.empty((len(note),T)) #用来存储观察值序列的矩阵，每行为一个观察值序列，每个序列长度固定为T=960\n",
    "Ti = 0\n",
    "\n",
    "for line in note:\n",
    "    line_rr = line.split(\",\")\n",
    "    states.append(line_rr[0])#将睡眠标签写进states数组\n",
    "    line_new = [] # 对应RR间期\n",
    "    for content in line_rr[1:-1]:  #去掉第一位标签，最后一位换行符\n",
    "        line_new.append(int(content)) #int 的作用是将str转为int\n",
    "\n",
    "#     print(line_rr[0])\n",
    "#     print(line_new)\n",
    "#     print(\"==================\")\n",
    "\n",
    "    # 为了保证观察序列长度一致，设置观察序列长度T=2*60*8=960，因此从序列最后开始倒数选取T个观察值\n",
    "    #（why倒数选取T个观察值：因为越后面的RR间期越能反应当前的睡眠分期）\n",
    "    line_new = line_new[-T:]\n",
    "    observations_data[Ti:] = line_new\n",
    "    Ti+=1\n",
    "\n",
    "print(observations_data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "observations = list(range(-37,38))\n",
    "Dic_observations = {observations[i]:i for i in range(75)}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "正在训练隐马尔科夫模型....\n",
      "[6.40453845e-267 0.00000000e+000 1.00000000e+000 0.00000000e+000\n",
      " 0.00000000e+000 1.69342240e-271]\n",
      "[[1.75492386e-001 1.57923179e-001 1.40174151e-001 1.93017133e-001\n",
      "  1.75470114e-001 1.57923037e-001]\n",
      " [1.84183268e-001 3.68433307e-001 1.57899994e-001 2.84922399e-144\n",
      "  2.63166657e-002 2.63166766e-001]\n",
      " [1.62164521e-001 1.08109688e-001 1.89177342e-001 5.40548439e-002\n",
      "  3.51356496e-001 1.35137110e-001]\n",
      " [2.35294118e-001 6.12603393e-024 1.17647059e-001 1.64130139e-234\n",
      "  5.29411765e-001 1.17647059e-001]\n",
      " [5.47218245e-001 7.13778739e-002 7.16820271e-002 4.76053900e-002\n",
      "  1.43102980e-001 1.19013484e-001]\n",
      " [2.42424266e-001 2.42424279e-001 3.03030333e-001 6.06060666e-002\n",
      "  9.09090998e-002 6.06059552e-002]]\n",
      "-640595.5250371859\n",
      "========\n",
      "[0.00000000e+000 6.93091083e-254 0.00000000e+000 1.00000000e+000\n",
      " 1.02137748e-307 0.00000000e+000]\n",
      "[[5.00000000e-001 1.03126329e-039 0.00000000e+000 3.80482294e-247\n",
      "  5.00000000e-001 0.00000000e+000]\n",
      " [1.21951327e-002 6.21951506e-001 6.09756631e-002 1.21950746e-001\n",
      "  1.58536686e-001 2.43902655e-002]\n",
      " [0.00000000e+000 4.54545455e-002 1.45037517e-155 1.81818182e-001\n",
      "  6.81818182e-001 9.09090909e-002]\n",
      " [0.00000000e+000 1.34321760e-001 5.97014438e-002 6.41792102e-001\n",
      "  8.95578910e-002 7.46268036e-002]\n",
      " [0.00000000e+000 3.57151849e-001 2.61904657e-001 2.38095066e-001\n",
      "  1.42848428e-001 1.36337325e-147]\n",
      " [0.00000000e+000 6.66666667e-001 2.22222222e-001 2.91292540e-065\n",
      "  1.11111110e-001 0.00000000e+000]]\n",
      "-639807.6803311153\n",
      "========\n",
      "[0. 0. 0. 0. 0. 1.]\n",
      "[[4.54545455e-001 9.09090909e-002 1.81818182e-001 9.09090909e-002\n",
      "  0.00000000e+000 1.81818182e-001]\n",
      " [0.00000000e+000 7.70833159e-001 2.50369292e-197 2.08333293e-002\n",
      "  0.00000000e+000 2.08333512e-001]\n",
      " [0.00000000e+000 8.00002278e-002 2.59999999e-001 2.19999998e-001\n",
      "  5.99999994e-002 3.79999777e-001]\n",
      " [1.15384559e-001 3.84615154e-002 2.69230809e-001 1.92307704e-001\n",
      "  0.00000000e+000 3.84615412e-001]\n",
      " [0.00000000e+000 0.00000000e+000 0.00000000e+000 6.66666114e-001\n",
      "  0.00000000e+000 3.33333886e-001]\n",
      " [3.48837442e-002 5.81395429e-002 3.37209340e-001 6.97674499e-002\n",
      "  0.00000000e+000 4.99999923e-001]]\n",
      "-639801.5625337682\n",
      "========\n"
     ]
    }
   ],
   "source": [
    "# 利用Baum-Welch算法求解HMM模型参数(HMM问题二)\n",
    "# 由于Baum-Welch(鲍姆-韦尔奇)算法是基于EM算法的近似算法，所以我们需要多跑几次，比如下面我们跑三次，选择一个比较优的模型参数，代码如下：\n",
    "import numpy as np\n",
    "from hmmlearn import hmm\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# states = [\"box 1\", \"box 2\", \"box3\"]\n",
    "states = ['1','2','3','4','R','W']\n",
    "n_states = len(states)\n",
    "\n",
    "# observations = [\"red\", \"white\"]\n",
    "observations = list(range(-37,38))\n",
    "Dic_observations = {observations[i]:i for i in range(75)}\n",
    "observations_data_index = np.empty(np.shape(observations_data))\n",
    "for line in range(225):\n",
    "    for colum in range(960):\n",
    "        observations_data_index[line][colum]=Dic_observations[observations_data[line][colum]]\n",
    "\n",
    "n_observations = len(observations)\n",
    "# hmmlearn实现了三种HMM模型类，按照观测状态是连续状态还是离散状态，可以分为两类。\n",
    "# GaussianHMM和GMMHMM是连续观测状态的HMM模型，而MultinomialHMM是离散观测状态的模型，也是我们在HMM原理系列篇里面使用的模型。\n",
    "\n",
    "# model = hmm.MultinomialHMM(n_components=n_states, n_iter=20, tol=0.01) #不满足多项式分布\n",
    "model = hmm.GaussianHMM(n_components=n_states, covariance_type='diag', n_iter=1000, tol=0.01) #本模型中observations为连续状态，满足高斯分布\n",
    "X = observations_data_index\n",
    "\n",
    "print(\"\\n正在训练隐马尔科夫模型....\")  \n",
    "model.fit(X)\n",
    "print(model.startprob_)\n",
    "print(model.transmat_)\n",
    "# print(model.emissionprob_)\n",
    "print(model.score(X))\n",
    "print(\"========\")\n",
    "model.fit(X)\n",
    "print(model.startprob_)\n",
    "print(model.transmat_)\n",
    "# print(model.emissionprob_)\n",
    "print(model.score(X))\n",
    "print(\"========\")\n",
    "model.fit(X)\n",
    "model.fit(X)\n",
    "print(model.startprob_)\n",
    "print(model.transmat_)\n",
    "# print(model.emissionprob_)\n",
    "print(model.score(X))\n",
    "print(\"========\")\n",
    "# 最终我们会选择分数最高的模型参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-3040.7282683274534\n"
     ]
    }
   ],
   "source": [
    "# Predict\n",
    "print(model.score([observations_data_index[0]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2, 3]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
